# Direct outputs not eligible for public release

# E.g., to run (replace <this-script>.R with this file's name):
# nohup R CMD BATCH --no-save --no-restore '--args outcomes=c("firm_move","firm_move_up","firm_move_down")' <this-script>.R & #nolint

# Evaluate each trailing command-line argument as R code in the global
# environment, e.g. '--args outcomes=c("firm_move")' defines `outcomes`.
# NOTE: eval(parse()) executes arbitrary code -- only pass trusted arguments.
# Iterating over `args` directly is a no-op when there are no arguments, so no
# separate length guard (or 1:length()) is needed.
args <- commandArgs(trailingOnly = TRUE)
for (arg_text in args) {
    eval(parse(text = arg_text))
}

# GLOBAL SETTINGS --------------------------------------------------------------

options(
    warn = 1,                          # print warnings as they occur
    show.error.locations = TRUE,
    max.print = .Machine$integer.max,
    digits = 16,
    scipen = 999                       # strongly prefer fixed over sci notation
)

# Main seed. The L'Ecuyer-CMRG generator is selected because the parallel
# driver at the bottom of the file uses mc.set.seed = TRUE, which gives each
# forked worker its own reproducible stream under this RNG kind.
seed <- 818675309L
set.seed(seed, kind = "L'Ecuyer-CMRG")

# PACKAGES ---------------------------------------------------------------------
library(data.table)
library(Matrix)
library(zoo)
library(multiwayvcov)
library(lmtest)
library(checkmate)
library(futile.logger)
library(lfe)

library(remotes)
# Install eventStudy from GitHub only when it is not already installed, so a
# batch run does not need the network (or silently pick up a newer commit) on
# every start.
# NOTE(review): for reproducibility consider pinning a ref, e.g.
# "setzler/eventStudy/eventStudy@<commit-sha>".
if (!requireNamespace("eventStudy", quietly = TRUE)) {
    remotes::install_github("setzler/eventStudy/eventStudy")
}
library(eventStudy) # for DiD-IV

# PACKAGE SETTINGS -------------------------------------------------------------

# data.table
# Limit data.table to a single thread, and make printed data.tables show each
# column's class (above the column) and the table's key.
data.table::setDTthreads(threads = 1L)
options(
    datatable.print.class = TRUE,
    datatable.print.keys = TRUE
)

# BEGIN FILE -------------------------------------------------------------------

# Helper code evaluated into the calling (here: top-level) environment;
# presumably this is where Wald_ES2 used below is defined -- confirm.
base::source(file = "~/code/0-utility-functions/wald_es.R", local = TRUE)

# Estimate lottery-wealth effects on one firm-move outcome and save them.
#
# Steps:
#   * Read the lottery panel.
#   * Attach each person-year's top-paying W2 firm (payer_tin) and that
#     firm's time-invariant mean wage (mean_weighted_resid).
#   * Keep person-years with main_job_W2 == 1.
#   * Run three stacked DiD-IV event studies with Wald_ES2: all units, plus
#     ntile_var_value 1 and 4 of per_adult_adjgross with ntiles = 4.
#   * Keep the estimates needed downstream and save them to
#     ~/estimation-output/wealth_effect_estimates_<outcome>.rds.
#
# Args:
#   outcome: character scalar, name of the outcome column to estimate.
#
# Returns:
#   `outcome`, unchanged, so the parallel driver can see which job finished.
MakeWealthFirmMoveEffects <- function(outcome) {

    # Read in lottery panel data
    lottery_panel <- readRDS("~/population-panel-data/lottery_panel_data.rds")

    # Fill in some zeros
    lottery_panel[is.na(db_w2_wages), db_w2_wages := 0]
    lottery_panel[is.na(self_empl_income), self_empl_income := 0]

    # For AGI (used to construct quartiles)
    lottery_panel[is.na(per_adult_adjgross), per_adult_adjgross := 0]

    # Each treated cohort will be restricted to age 21-64 in year 0 and
    # only control units age 21-64 in the same calendar year will be used
    lottery_panel[
        ,
        age_case := as.logical(between(age, 21, 64, incbounds = TRUE))
    ]

    # Will focus on the population initially in paid employment
    lottery_panel[, paid_employment := as.logical(main_job_W2 == 1)]

    # Read in firm panel with firm IDs for top-paying W2 firm
    firm_tin_dt <- readRDS("~/population-panel-data/top-w2-firm-panel.rds")
    firm_tin_dt[, payer_tin := as.integer(payer_tin)]

    # Read in the time-invariant firm-level mean wages
    # key variable: mean_weighted_resid (average mean wages net of time effects)
    firm_wages <-
        readRDS(
            "~/population-crossection-data/firm_mean_earnings_time_invariant.rds"
        )
    firm_wages[, payer_tin := as.integer(payer_tin)]

    firm_tin_dt <-
        merge(
            firm_tin_dt,
            firm_wages,
            by = "payer_tin",
            all.x = TRUE,
            sort = FALSE
        )
    rm(firm_wages)

    # Merge on firm info
    lottery_panel <-
        merge(
            lottery_panel,
            firm_tin_dt,
            by = c("tin", "tax_yr"),
            all.x = TRUE,
            sort = FALSE
        )
    rm(firm_tin_dt)

    # Filling in 0s on the firm-info just so observations don't get dropped.
    # payer_tin was converted with as.integer() above, so fill with 0L.
    lottery_panel[is.na(payer_tin), payer_tin := 0L]
    lottery_panel[is.na(mean_weighted_resid), mean_weighted_resid := 0]

    # Have not actually constructed the outcome yet.
    # NOTE(review): the outcome column is filled with 0 for every row and is
    # never built in this file, so the estimates below are expected to be
    # degenerate until the firm-move indicator is constructed (presumably from
    # payer_tin / mean_weighted_resid relative to cond_pos_var_time) -- confirm.
    setorderv(lottery_panel, c("tin", "tax_yr"))
    lottery_panel[, (outcome) := 0L]

    # Constrain to paid employed (and then later in omitted_event_time too)
    lottery_panel <- lottery_panel[main_job_W2 == 1]
    add_unit_fes <- TRUE

    # Setup timing -- omit -1 and use -2 to construct firm-switch variables
    anticipation <- 0
    omitted_event_time <- -1
    cond_pos_var_time <- -2

    # Produce deliberately aggregated estimates: grouping name (a) and the
    # event times it collapses over (b)
    collapse_table <-
        data.table(
            a = c("one_to_two", "three_to_five", "post_avg"),
            b = c(list(1:2), list(3:5), list(1:5))
        )

    # Run one stacked event-study regression on a fresh copy of the panel.
    # All specifications share these settings; `...` carries the extra
    # (income-ntile) arguments that differ between specifications.
    RunWaldES <- function(...) {
        Wald_ES2(
            long_data = copy(lottery_panel),
            outcomevar = outcome,
            unit_var = "tin",
            cal_time_var = "tax_yr",
            onset_time_var = "win_yr",
            cluster_vars = "tin",
            omitted_event_time = omitted_event_time,
            discrete_covars = "age",
            control_subset_var = "age_case",
            control_subset_event_time = 0,
            treated_subset_var = "age_case",
            treated_subset_event_time = 0,
            control_subset_var2 = "paid_employment",
            control_subset_event_time2 = cond_pos_var_time,
            treated_subset_var2 = "paid_employment",
            treated_subset_event_time2 = cond_pos_var_time,
            heterogeneous_only = TRUE,
            anticipation = anticipation,
            endog_var = "L_multiperiod",
            calculate_collapse_estimates = TRUE,
            collapse_inputs = collapse_table,
            add_unit_fes = add_unit_fes,
            ...
        )
    }

    # Same specification for one quartile (ntiles = 4) of per_adult_adjgross,
    # with ntile_event_time set to the omitted event time.
    RunWaldESByIncomeQuartile <- function(quartile) {
        RunWaldES(
            ntile_var = "per_adult_adjgross",
            ntile_event_time = omitted_event_time,
            ntiles = 4,
            ntile_var_value = quartile,
            ntile_avg = FALSE
        )
    }

    did_results_temp <- RunWaldES()
    did_results_q1_temp <- RunWaldESByIncomeQuartile(quartile = 1)
    did_results_q4_temp <- RunWaldESByIncomeQuartile(quartile = 4)

    # Above results contain more than is needed for this step, so focus them
    # with a quick subsetting function. Returns a data.table with columns
    # ref_event_time, estimate, cluster_se and model, holding:
    #   * cohort-weighted reduced-form ATTs by event time;
    #   * collapsed "ratio"-model ATTs, coded 97/98/99;
    #   * the collapsed counterfactual untreated mean of the treated
    #     (model "counterfactual_mean", coded 99).
    MakeMainFirmMoveEstimates <- function(did_dt) {

        est_dt <- as.data.table(did_dt[[1]])

        # reduced-form / event-study estimates
        doutcome_dt <-
            est_dt[
                rn == "att" &
                    model == "reduced_form" &
                    ref_onset_time == "Cohort-Weighted",
                .(ref_event_time, estimate, cluster_se, model)
            ]

        # collapsed estimates for the collapse_table groupings
        doutcome_dl_dt_collapsed <-
            est_dt[
                rn == "att" &
                    model == "ratio" &
                    ref_onset_time == "Cohort-Weighted + Collapsed" &
                    grouping %in% collapse_table$a,
                .(ref_event_time, estimate, cluster_se, grouping, model)
            ]

        # Will use a 97, 98, 99 convention to get
        # the right order of results in tables
        doutcome_dl_dt_collapsed[
            grouping %in% c("one_to_two", "first_two"),
            ref_event_time := 97
        ]
        doutcome_dl_dt_collapsed[
            grouping == "three_to_five",
            ref_event_time := 98
        ]
        doutcome_dl_dt_collapsed[grouping == "post_avg", ref_event_time := 99]
        doutcome_dl_dt_collapsed[, grouping := NULL]

        # Counterfactual untreated mean of the treated group
        # if E[y(1)-y(0)|D=1] = E[y1-y0|D=1] - E[y1-y0|D=0] (= Reduced form DiD)
        # then E[y(0)|D = 1] = E[y(1)] - RF DiD
        did_cohort_et <-
            est_dt[
                model == "reduced_form" & rn == "catt",
                .(
                    ref_onset_time,
                    ref_event_time,
                    did = estimate,
                    catt_treated_unique_units
                )
            ]

        post_mean <-
            est_dt[
                model == "reduced_form" &
                    rn == "treatment_means" &
                    treated == 1,
                .(ref_onset_time, ref_event_time, post_mean = estimate)
            ]

        # note: omitted_event_time is not among these rows
        cfactual_et <-
            merge(
                did_cohort_et,
                post_mean,
                by = c("ref_onset_time", "ref_event_time"),
                all.x = TRUE,
                sort = FALSE
            )
        rm(post_mean, did_cohort_et)

        cfactual_et[, estimate := post_mean - did]

        # Cohort-weighted event-time version (weights: treated units per cohort)
        cfactual_t <-
            cfactual_et[
                ,
                .(
                    estimate =
                        weighted.mean(
                            x = estimate,
                            w = catt_treated_unique_units
                        )
                ),
                by = .(ref_event_time)
            ]

        # Post-period (event times 1-5) average, coded 99. Built with exactly
        # the four columns of the other two pieces: rbindlist(use.names = TRUE)
        # without fill = TRUE errors when the column counts differ.
        cfactual_collapsed <-
            cfactual_t[
                ref_event_time %in% 1:5,
                .(
                    ref_event_time = 99L,
                    estimate = mean(estimate),
                    cluster_se = 0,
                    model = "counterfactual_mean"
                )
            ]

        rbindlist(
            list(doutcome_dt, doutcome_dl_dt_collapsed, cfactual_collapsed),
            use.names = TRUE
        )
    }

    # income_quartile 5 codes the pooled (no quartile restriction) sample
    did_results_agg <-
        MakeMainFirmMoveEstimates(did_dt = did_results_temp)
    did_results_agg[, income_quartile := 5L]

    did_results_q1 <-
        MakeMainFirmMoveEstimates(did_dt = did_results_q1_temp)
    did_results_q1[, income_quartile := 1L]

    did_results_q4 <-
        MakeMainFirmMoveEstimates(did_dt = did_results_q4_temp)
    did_results_q4[, income_quartile := 4L]

    did_results <-
        rbindlist(
            list(
                did_results_agg,
                did_results_q1,
                did_results_q4
            ),
            use.names = TRUE
        )

    # Only need the reduced form for the aggregate estimates
    did_results <-
        did_results[
            (model == "reduced_form" & income_quartile == 5) |
                (model != "reduced_form")
        ]

    saveRDS(
        did_results,
        sprintf("~/estimation-output/wealth_effect_estimates_%s.rds", outcome) # nolint
    )

    outcome
}

# Run MakeWealthFirmMoveEffects once per outcome, one forked worker per
# outcome. `outcomes` must be supplied on the command line, e.g.
# '--args outcomes=c("firm_move","firm_move_up")'.
# mcmapply lives in the base `parallel` package, which is not attached by
# default, so it is called with an explicit namespace.
if (!exists("outcomes") || length(outcomes) == 0L) {
    stop(
        "`outcomes` is undefined or empty; pass it via ",
        "'--args outcomes=c(\"firm_move\")'.",
        call. = FALSE
    )
}
num_cores <- length(outcomes)
wealth_effect_runs <- parallel::mcmapply(
    MakeWealthFirmMoveEffects,
    outcomes,
    SIMPLIFY = FALSE,
    mc.silent = FALSE,
    mc.cores = num_cores,
    mc.set.seed = TRUE
)
print(wealth_effect_runs)

# A failed worker comes back as a "try-error" object (or NULL if it delivered
# no result) with only a warning; make such failures stop the batch run.
failed_runs <- vapply(
    wealth_effect_runs,
    function(run) is.null(run) || inherits(run, "try-error"),
    logical(1)
)
if (any(failed_runs)) {
    stop(
        "Estimation failed for outcome(s): ",
        paste(names(wealth_effect_runs)[failed_runs], collapse = ", "),
        call. = FALSE
    )
}
